{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "635d8ebb",
      "metadata": {},
      "source": [
        "# Adaptive RAG\n",
        "\n",
        "- Author: [Teddy](https://github.com/teddylee777)\n",
        "- Design: [Teddy](https://github.com/teddylee777)\n",
        "- Peer Review: [Teddy](https://github.com/teddylee777)\n",
        "- This is a part of [LangChain Open Tutorial](https://github.com/LangChain-OpenTutorial/LangChain-OpenTutorial)\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LangChain-OpenTutorial/LangChain-OpenTutorial/blob/main/99-TEMPLATE/00-BASE-TEMPLATE-EXAMPLE.ipynb) [![Open in GitHub](https://img.shields.io/badge/Open%20in%20GitHub-181717?style=flat-square&logo=github&logoColor=white)](https://github.com/LangChain-OpenTutorial/LangChain-OpenTutorial/blob/main/99-TEMPLATE/00-BASE-TEMPLATE-EXAMPLE.ipynb)\n",
        "\n",
        "## Overview\n",
        "\n",
        "이 튜토리얼은 Adaptive RAG(Adaptive Retrieval-Augmented Generation)의 구현을 다룹니다. \n",
        "\n",
        "Adaptive RAG는 쿼리 분석과 능동적/자기 수정 RAG를 결합하여 다양한 데이터 소스에서 정보를 검색하고 생성하는 전략입니다. \n",
        "\n",
        "이 튜토리얼에서는 LangGraph를 사용하여 웹 검색과 자기 수정 RAG 간의 라우팅을 구현합니다.\n",
        "\n",
        "![adaptive-rag](./assets/langgraph-adaptive-rag.png)\n",
        "\n",
        "**Adaptive RAG**는 **RAG**의 전략으로, Query Construction 과 Self-Reflective RAG 을 결합합니다.\n",
        "\n",
        "[논문: Adaptive-RAG: Learning to Adapt Retrieval-Augmented Large Language Models through Question Complexity](https://arxiv.org/abs/2403.14403) 에서는 쿼리 분석을 통해 다음과 같은 라우팅을 수행합니다.\n",
        "\n",
        "- `No Retrieval`\n",
        "- `Single-shot RAG`\n",
        "- `Iterative RAG`\n",
        "\n",
        "이번 튜토리얼에서는 LangGraph를 사용하여 다음과 같은 라우팅을 수행하는 예제를 구현합니다.\n",
        "\n",
        "- **Web Search**: 최신 이벤트와 관련된 질문에 사용\n",
        "- **Self-Reflective RAG**: 인덱스와 관련된 질문에 사용\n",
        "\n",
        "### Table of Contents\n",
        "\n",
        "- [Overview](#overview)\n",
        "- [Environement Setup](#environment-setup)\n",
        "- [Create a basic PDF-based Retrieval Chain](#create-a-basic-pdf-based-retrieval-chain)\n",
        "- [Query routing and document evaluation](#query-routing-and-document-evaluation)\n",
        "- [Tools](#tools)\n",
        "- [Graph Construction](#graph-construction)\n",
        "- [Define Graph Flows](#define-graph-flows)\n",
        "- [Define Nodes](#define-nodes)\n",
        "- [Graph Construction](#graph-construction)\n",
        "- [Execute Graph](#execute-graph)\n",
        "\n",
        "### References\n",
        "\n",
        "- [LangChain: Query Construction](https://blog.langchain.dev/query-construction/)\n",
        "- [LangGraph: Self-Reflective RAG](https://blog.langchain.dev/agentic-rag-with-langgraph/)\n",
        "- [Adaptive-RAG: Learning to Adapt Retrieval-Augmented Large Language Models through Question Complexity](https://arxiv.org/abs/2403.14403)\n",
        "----"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c6c7aba4",
      "metadata": {},
      "source": [
        "## Environment Setup\n",
        "\n",
        "Set up the environment. You may refer to [Environment Setup](https://wikidocs.net/257836) for more details.\n",
        "\n",
        "**[Note]**\n",
        "- `langchain-opentutorial` is a package that provides a set of easy-to-use environment setup, useful functions and utilities for tutorials. \n",
        "- You can checkout the [`langchain-opentutorial`](https://github.com/LangChain-OpenTutorial/langchain-opentutorial-pypi) for more details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "21943adb",
      "metadata": {},
      "outputs": [],
      "source": [
        "%%capture --no-stderr\n",
        "!pip install langchain-opentutorial"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "f25ec196",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n",
            "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
            "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
          ]
        }
      ],
      "source": [
        "# Install required packages\n",
        "from langchain_opentutorial import package\n",
        "\n",
        "package.install(\n",
        "    [\n",
        "        \"langsmith\",\n",
        "        \"langchain\",\n",
        "        \"langchain_core\",\n",
        "        \"langchain-anthropic\",\n",
        "        \"langchain_community\",\n",
        "        \"langchain_text_splitters\",\n",
        "        \"langchain_openai\",\n",
        "    ],\n",
        "    verbose=False,\n",
        "    upgrade=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "7f9065ea",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Environment variables have been set successfully.\n"
          ]
        }
      ],
      "source": [
        "# Set environment variables\n",
        "from langchain_opentutorial import set_env\n",
        "\n",
        "set_env(\n",
        "    {\n",
        "        \"OPENAI_API_KEY\": \"\",\n",
        "        \"LANGCHAIN_API_KEY\": \"\",\n",
        "        \"LANGCHAIN_TRACING_V2\": \"true\",\n",
        "        \"LANGCHAIN_ENDPOINT\": \"https://api.smith.langchain.com\",\n",
        "        \"LANGCHAIN_PROJECT\": \"Adaptive-RAG\",  # title 과 동일하게 설정해 주세요\n",
        "    }\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "690a9ae0",
      "metadata": {},
      "source": [
        "You can alternatively set API keys such as `OPENAI_API_KEY` in a `.env` file and load them.\n",
        "\n",
        "[Note] This is not necessary if you've already set the required API keys in previous steps."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4f99b5b6",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Load API keys from .env file\n",
        "from dotenv import load_dotenv\n",
        "\n",
        "load_dotenv(override=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "616661ad",
      "metadata": {},
      "source": [
        "## 참고 (이미지 파일명 관련)\n",
        "\n",
        "튜토리얼 파일 작성시 `assets` 에 이미지를 추가하여 이를 markdown 으로 추가하는 경우가 있습니다.\n",
        "\n",
        "이때 image 파일명의 통일성을 가져가지 위하여 가이드 전달 드립니다.\n",
        "\n",
        "**이미지 파일명**\n",
        "1. 이미지 파일명은 모두 **영문 소문자**로 작성합니다.\n",
        "2. 이미지 파일에 공백은 없어야 하며 공백 대신 `-` 하이픈으로 대체 합니다.\n",
        "\n",
        "jupyter notebook 파일명 + 이미지 제목 + 필요시 숫자(01, 02, 03, ...)\n",
        "\n",
        "예시)\n",
        "`10-LangGraph-Self-RAG.ipynb`  인 경우\n",
        "\n",
        "이미지 파일명: \n",
        "- `10-langgraph-self-rag-flow-explanation.png` : OK\n",
        "- `10-langgraph-self-rag-flow-explanation-01.png`: OK\n",
        "- `10-langgraph-self-rag-flow-explanation-02.png`: OK"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "17efec71",
      "metadata": {},
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "markdown",
      "id": "aa00c3f4",
      "metadata": {},
      "source": [
        "## Create a basic PDF-based Retrieval Chain\n",
        "\n",
        "여기서는 PDF 문서를 기반으로 Retrieval Chain 을 생성합니다. 가장 단순한 구조의 Retrieval Chain 입니다.\n",
        "\n",
        "단, LangGraph 에서는 Retirever 와 Chain 을 따로 생성합니다. 그래야 각 노드별로 세부 처리를 할 수 있습니다.\n",
        "\n",
        "**참고**\n",
        "- 이전 튜토리얼에서 다룬 내용이므로, 자세한 설명은 생략합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "69cb77da",
      "metadata": {},
      "outputs": [],
      "source": [
        "from rag.pdf import PDFRetrievalChain\n",
        "\n",
        "# PDF 문서를 로드합니다.\n",
        "pdf = PDFRetrievalChain([\"data/SPRI_AI_Brief_2023년12월호_F.pdf\"]).create_chain()\n",
        "\n",
        "# retriever 생성\n",
        "pdf_retriever = pdf.retriever\n",
        "\n",
        "# chain 생성\n",
        "pdf_chain = pdf.chain"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "2b2fc536",
      "metadata": {},
      "source": [
        "## Query routing and document evaluation\n",
        "\n",
        "이번 단계에서는 **쿼리 라우팅** 과 **문서 평가** 를 수행합니다. 이 과정은 **Adaptive RAG** 의 중요한 부분으로, 효율적인 정보 검색과 생성에 기여합니다.\n",
        "\n",
        "- **쿼리 라우팅** : 사용자의 쿼리를 분석하여 적절한 정보 소스로 라우팅합니다. 이를 통해 쿼리의 목적에 맞는 최적의 검색 경로를 설정할 수 있습니다.\n",
        "- **문서 평가** : 검색된 문서의 품질과 관련성을 평가하여 최종 결과의 정확성을 높입니다. \n",
        "\n",
        "이 단계는 **Adaptive RAG** 의 핵심 기능을 지원하며, 정확하고 신뢰할 수 있는 정보 제공을 목표로 합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "1b78d33f",
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import Literal\n",
        "\n",
        "from langchain_core.prompts import ChatPromptTemplate\n",
        "from pydantic import BaseModel, Field\n",
        "from langchain_openai import ChatOpenAI\n",
        "from langchain_teddynote.models import get_model_name, LLMs\n",
        "\n",
        "# 최신 LLM 모델 이름 가져오기\n",
        "MODEL_NAME = get_model_name(LLMs.GPT4)\n",
        "\n",
        "\n",
        "# 사용자 쿼리를 가장 관련성 높은 데이터 소스로 라우팅하는 데이터 모델\n",
        "class RouteQuery(BaseModel):\n",
        "    \"\"\"Route a user query to the most relevant datasource.\"\"\"\n",
        "\n",
        "    # 데이터 소스 선택을 위한 리터럴 타입 필드\n",
        "    datasource: Literal[\"vectorstore\", \"web_search\"] = Field(\n",
        "        ...,\n",
        "        description=\"Given a user question choose to route it to web search or a vectorstore.\",\n",
        "    )\n",
        "\n",
        "\n",
        "# LLM 초기화 및 함수 호출을 통한 구조화된 출력 생성\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "structured_llm_router = llm.with_structured_output(RouteQuery)\n",
        "\n",
        "# 시스템 메시지와 사용자 질문을 포함한 프롬프트 템플릿 생성\n",
        "system = \"\"\"You are an expert at routing a user question to a vectorstore or web search.\n",
        "The vectorstore contains documents related to DEC 2023 AI Brief Report(SPRI) with Samsung Gause, Anthropic, etc.\n",
        "Use the vectorstore for questions on these topics. Otherwise, use web-search.\"\"\"\n",
        "\n",
        "# Routing 을 위한 프롬프트 템플릿 생성\n",
        "route_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\"human\", \"{question}\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# 프롬프트 템플릿과 구조화된 LLM 라우터를 결합하여 질문 라우터 생성\n",
        "question_router = route_prompt | structured_llm_router"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c9e4d831",
      "metadata": {},
      "source": [
        "다음은 쿼리 라우팅 결과를 테스트 해본 뒤 결과를 확인합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0874c14b",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 문서 검색이 필요한 질문\n",
        "print(\n",
        "    question_router.invoke(\n",
        "        {\"question\": \"AI Brief 에서 삼성전자가 만든 생성형 AI 의 이름은?\"}\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a2d22b26",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 웹 검색이 필요한 질문\n",
        "print(question_router.invoke({\"question\": \"판교에서 가장 맛있는 딤섬집 찾아줘\"}))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5fc43b99",
      "metadata": {},
      "source": [
        "### 검색 평가기(Retrieval Grader)\n",
        "\n",
        "검색 평가기에 대한 내용..."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "id": "d1221d80",
      "metadata": {},
      "outputs": [],
      "source": [
        "from pydantic import BaseModel, Field\n",
        "from langchain_openai import ChatOpenAI\n",
        "from langchain_core.prompts import ChatPromptTemplate\n",
        "\n",
        "\n",
        "# 문서 평가를 위한 데이터 모델 정의\n",
        "class GradeDocuments(BaseModel):\n",
        "    \"\"\"Binary score for relevance check on retrieved documents.\"\"\"\n",
        "\n",
        "    binary_score: str = Field(\n",
        "        description=\"Documents are relevant to the question, 'yes' or 'no'\"\n",
        "    )\n",
        "\n",
        "\n",
        "# LLM 초기화 및 함수 호출을 통한 구조화된 출력 생성\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "structured_llm_grader = llm.with_structured_output(GradeDocuments)\n",
        "\n",
        "# 시스템 메시지와 사용자 질문을 포함한 프롬프트 템플릿 생성\n",
        "system = \"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n \n",
        "    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \\n\n",
        "    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n",
        "    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\"\"\"\n",
        "\n",
        "grade_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\"human\", \"Retrieved document: \\n\\n {document} \\n\\n User question: {question}\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# 문서 검색결과 평가기 생성\n",
        "retrieval_grader = grade_prompt | structured_llm_grader"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "927cac10",
      "metadata": {},
      "source": [
        "생성한 `retrieval_grader` 를 사용하여 **문서 검색결과** 를 평가합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "2fa5e0d7",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 사용자 질문 설정\n",
        "question = \"삼성전자가 만든 생성형 AI 의 이름은?\"\n",
        "\n",
        "# 질문에 대한 관련 문서 검색\n",
        "docs = pdf_retriever.invoke(question)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ef397b71",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 검색된 문서의 내용 가져오기\n",
        "retrieved_doc = docs[1].page_content\n",
        "\n",
        "# 평가 결과 출력\n",
        "print(retrieval_grader.invoke({\"question\": question, \"document\": retrieved_doc}))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "id": "dce41bfd",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 필터링 하는 코드 예시\n",
        "filtered_docs = []\n",
        "\n",
        "\n",
        "for doc in docs:\n",
        "    # 문서 평가 결과 확인\n",
        "    result = retrieval_grader.invoke(\n",
        "        {\n",
        "            \"question\": question,\n",
        "            \"document\": doc.page_content,\n",
        "        }\n",
        "    )\n",
        "    # 관련성이 있는 문서만 필터링\n",
        "    if result.binary_score == \"yes\":\n",
        "        filtered_docs.append(doc)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "54dce7a1",
      "metadata": {},
      "source": [
        "### 답변 생성을 위한 RAG 체인 생성"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "992ef15a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain import hub\n",
        "from langchain_core.output_parsers import StrOutputParser\n",
        "from langchain_openai import ChatOpenAI\n",
        "\n",
        "# LangChain Hub에서 프롬프트 가져오기(RAG 프롬프트는 자유롭게 수정 가능)\n",
        "prompt = hub.pull(\"teddynote/rag-prompt\")\n",
        "\n",
        "# LLM 초기화\n",
        "llm = ChatOpenAI(model_name=MODEL_NAME, temperature=0)\n",
        "\n",
        "\n",
        "# 문서 포맷팅 함수\n",
        "def format_docs(docs):\n",
        "    return \"\\n\\n\".join(\n",
        "        [\n",
        "            f'<document><content>{doc.page_content}</content><source>{doc.metadata[\"source\"]}</source><page>{doc.metadata[\"page\"]+1}</page></document>'\n",
        "            for doc in docs\n",
        "        ]\n",
        "    )\n",
        "\n",
        "\n",
        "# RAG 체인 생성\n",
        "rag_chain = prompt | llm | StrOutputParser()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0fbc96e3",
      "metadata": {},
      "source": [
        "이제 생성한 `rag_chain` 에 질문을 전달하여 답변을 생성합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f8d16e04",
      "metadata": {},
      "outputs": [],
      "source": [
        "# RAG 체인에 질문을 전달하여 답변 생성\n",
        "generation = rag_chain.invoke({\"context\": format_docs(docs), \"question\": question})\n",
        "print(generation)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a0e9f601",
      "metadata": {},
      "source": [
        "### 답변의 Hallucination 체커 추가"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "40ec0e97",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 할루시네이션 체크를 위한 데이터 모델 정의\n",
        "class GradeHallucinations(BaseModel):\n",
        "    \"\"\"Binary score for hallucination present in generation answer.\"\"\"\n",
        "\n",
        "    binary_score: str = Field(\n",
        "        description=\"Answer is grounded in the facts, 'yes' or 'no'\"\n",
        "    )\n",
        "\n",
        "\n",
        "# 함수 호출을 통한 LLM 초기화\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "structured_llm_grader = llm.with_structured_output(GradeHallucinations)\n",
        "\n",
        "# 프롬프트 설정\n",
        "system = \"\"\"You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \\n \n",
        "    Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts.\"\"\"\n",
        "\n",
        "# 프롬프트 템플릿 생성\n",
        "hallucination_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\"human\", \"Set of facts: \\n\\n {documents} \\n\\n LLM generation: {generation}\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# 환각 평가기 생성\n",
        "hallucination_grader = hallucination_prompt | structured_llm_grader"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8550b7cf",
      "metadata": {},
      "source": [
        "생성한 `hallucination_grader` 를 사용하여 생성된 답변의 환각 여부를 평가합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cb593684",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 평가기를 사용하여 생성된 답변의 환각 여부 평가\n",
        "hallucination_grader.invoke({\"documents\": docs, \"generation\": generation})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "id": "110eb9b0",
      "metadata": {},
      "outputs": [],
      "source": [
        "class GradeAnswer(BaseModel):\n",
        "    \"\"\"Binary scoring to evaluate the appropriateness of answers to questions\"\"\"\n",
        "\n",
        "    binary_score: str = Field(\n",
        "        description=\"Indicate 'yes' or 'no' whether the answer solves the question\"\n",
        "    )\n",
        "\n",
        "\n",
        "# 함수 호출을 통한 LLM 초기화\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "structured_llm_grader = llm.with_structured_output(GradeAnswer)\n",
        "\n",
        "# 프롬프트 설정\n",
        "system = \"\"\"You are a grader assessing whether an answer addresses / resolves a question \\n \n",
        "     Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question.\"\"\"\n",
        "answer_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\"human\", \"User question: \\n\\n {question} \\n\\n LLM generation: {generation}\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# 프롬프트 템플릿과 구조화된 LLM 평가기를 결합하여 답변 평가기 생성\n",
        "answer_grader = answer_prompt | structured_llm_grader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "66a26ad6",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 평가기를 사용하여 생성된 답변이 질문을 해결하는지 여부 평가\n",
        "answer_grader.invoke({\"question\": question, \"generation\": generation})"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a9fc11dd",
      "metadata": {},
      "source": [
        "### 쿼리 재작성(Query Rewriter)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "id": "e9df325a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_openai import ChatOpenAI\n",
        "from langchain_core.prompts import ChatPromptTemplate\n",
        "from langchain_core.output_parsers import StrOutputParser\n",
        "\n",
        "# LLM 초기화\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "\n",
        "# Query Rewriter 프롬프트 정의(자유롭게 수정이 가능합니다)\n",
        "system = \"\"\"You a question re-writer that converts an input question to a better version that is optimized \\n \n",
        "for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.\"\"\"\n",
        "\n",
        "# Query Rewriter 프롬프트 템플릿 생성\n",
        "re_write_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\n",
        "            \"human\",\n",
        "            \"Here is the initial question: \\n\\n {question} \\n Formulate an improved question.\",\n",
        "        ),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# Query Rewriter 생성\n",
        "question_rewriter = re_write_prompt | llm | StrOutputParser()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0abd3e83",
      "metadata": {},
      "source": [
        "생성한 `question_rewriter` 에 질문을 전달하여 개선된 질문을 생성합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c6eb92e7",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 질문 재작성기에 질문을 전달하여 개선된 질문 생성\n",
        "question_rewriter.invoke({\"question\": question})"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d8d5ee42",
      "metadata": {},
      "source": [
        "## Tools\n",
        "\n",
        "### 웹 검색 도구\n",
        "\n",
        "**웹 검색 도구**는 **Adaptive RAG**의 중요한 구성 요소로, 최신 정보를 검색하는 데 사용됩니다. 이 도구는 사용자가 최신 이벤트와 관련된 질문에 대해 신속하고 정확한 답변을 얻을 수 있도록 지원합니다.\n",
        "\n",
        "- **설정**: 웹 검색 도구를 설정하여 최신 정보를 검색할 수 있도록 준비합니다.\n",
        "- **검색 수행**: 사용자의 쿼리를 기반으로 웹에서 관련 정보를 검색합니다.\n",
        "- **결과 분석**: 검색된 결과를 분석하여 사용자의 질문에 가장 적합한 정보를 제공합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "id": "e004263c",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_teddynote.tools.tavily import TavilySearch\n",
        "\n",
        "# 웹 검색 도구 생성\n",
        "web_search_tool = TavilySearch(max_results=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "63d60abe",
      "metadata": {},
      "source": [
        "웹 검색 도구를 실행하여 결과를 확인합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c13be8f3",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 웹 검색 도구 호출\n",
        "result = web_search_tool.search(\"테디노트 위키독스 랭체인 튜토리얼 URL 을 알려주세요\")\n",
        "print(result)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1904c95c",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 웹 검색 결과의 첫 번째 결과 확인\n",
        "result[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1ac37855",
      "metadata": {},
      "source": [
        "## Graph Construction"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "70ab91c2",
      "metadata": {},
      "source": [
        "### Defining graph states"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "id": "6d23ab6f",
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import List\n",
        "from typing_extensions import TypedDict, Annotated\n",
        "\n",
        "\n",
        "# 그래프의 상태 정의\n",
        "class GraphState(TypedDict):\n",
        "    \"\"\"\n",
        "    그래프의 상태를 나타내는 데이터 모델\n",
        "\n",
        "    Attributes:\n",
        "        question: 질문\n",
        "        generation: LLM 생성된 답변\n",
        "        documents: 도큐먼트 리스트\n",
        "    \"\"\"\n",
        "\n",
        "    question: Annotated[str, \"User question\"]\n",
        "    generation: Annotated[str, \"LLM generated answer\"]\n",
        "    documents: Annotated[List[str], \"List of documents\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f266cc42",
      "metadata": {},
      "source": [
        "## Define Graph Flows\n",
        "\n",
        "**그래프 흐름**을 정의하여 **Adaptive RAG**의 작동 방식을 명확히 합니다. 이 단계에서는 그래프의 상태와 전환을 설정하여 쿼리 처리의 효율성을 높입니다.\n",
        "\n",
        "- **상태 정의**: 그래프의 각 상태를 명확히 정의하여 쿼리의 진행 상황을 추적합니다.\n",
        "- **전환 설정**: 상태 간의 전환을 설정하여 쿼리가 적절한 경로를 따라 진행되도록 합니다.\n",
        "- **흐름 최적화**: 그래프의 흐름을 최적화하여 정보 검색과 생성의 정확성을 향상시킵니다."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "633bf00c",
      "metadata": {},
      "source": [
        "### Define Nodes\n",
        "\n",
        "활용할 노드를 정의합니다.\n",
        "\n",
        "- `retrieve`: 문서 검색 노드\n",
        "- `generate`: 답변 생성 노드\n",
        "- `grade_documents`: 문서 관련성 평가 노드\n",
        "- `transform_query`: 질문 재작성 노드\n",
        "- `web_search`: 웹 검색 노드\n",
        "- `route_question`: 질문 라우팅 노드\n",
        "- `decide_to_generate`: 답변 생성 결정 노드\n",
        "- `hallucination_check`: 환각 평가 노드"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "id": "ee6f34d0",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_core.documents import Document\n",
        "\n",
        "\n",
        "# 문서 검색 노드\n",
        "def retrieve(state):\n",
        "    print(\"==== [RETRIEVE] ====\")\n",
        "    question = state[\"question\"]\n",
        "\n",
        "    # 문서 검색 수행\n",
        "    documents = pdf_retriever.invoke(question)\n",
        "    return {\"documents\": documents}\n",
        "\n",
        "\n",
        "# 답변 생성 노드\n",
        "def generate(state):\n",
        "    print(\"==== [GENERATE] ====\")\n",
        "    # 질문과 문서 검색 결과 가져오기\n",
        "    question = state[\"question\"]\n",
        "    documents = state[\"documents\"]\n",
        "\n",
        "    # RAG 답변 생성\n",
        "    generation = rag_chain.invoke({\"context\": documents, \"question\": question})\n",
        "    return {\"generation\": generation}\n",
        "\n",
        "\n",
        "# 문서 관련성 평가 노드\n",
        "def grade_documents(state):\n",
        "    print(\"==== [CHECK DOCUMENT RELEVANCE TO QUESTION] ====\")\n",
        "    # 질문과 문서 검색 결과 가져오기\n",
        "    question = state[\"question\"]\n",
        "    documents = state[\"documents\"]\n",
        "\n",
        "    # 각 문서에 대한 관련성 점수 계산\n",
        "    filtered_docs = []\n",
        "    for d in documents:\n",
        "        score = retrieval_grader.invoke(\n",
        "            {\"question\": question, \"document\": d.page_content}\n",
        "        )\n",
        "        grade = score.binary_score\n",
        "        if grade == \"yes\":\n",
        "            print(\"---GRADE: DOCUMENT RELEVANT---\")\n",
        "            # 관련성이 있는 문서 추가\n",
        "            filtered_docs.append(d)\n",
        "        else:\n",
        "            # 관련성이 없는 문서는 건너뛰기\n",
        "            print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n",
        "            continue\n",
        "    return {\"documents\": filtered_docs}\n",
        "\n",
        "\n",
        "# 질문 재작성 노드\n",
        "def transform_query(state):\n",
        "    print(\"==== [TRANSFORM QUERY] ====\")\n",
        "    # 질문과 문서 검색 결과 가져오기\n",
        "    question = state[\"question\"]\n",
        "    documents = state[\"documents\"]\n",
        "\n",
        "    # 질문 재작성\n",
        "    better_question = question_rewriter.invoke({\"question\": question})\n",
        "    return {\"question\": better_question}\n",
        "\n",
        "\n",
        "# 웹 검색 노드\n",
        "def web_search(state):\n",
        "    print(\"==== [WEB SEARCH] ====\")\n",
        "    # 질문과 문서 검색 결과 가져오기\n",
        "    question = state[\"question\"]\n",
        "\n",
        "    # 웹 검색 수행\n",
        "    web_results = web_search_tool.invoke({\"query\": question})\n",
        "    web_results_docs = [\n",
        "        Document(\n",
        "            page_content=web_result[\"content\"],\n",
        "            metadata={\"source\": web_result[\"url\"]},\n",
        "        )\n",
        "        for web_result in web_results\n",
        "    ]\n",
        "\n",
        "    return {\"documents\": web_results_docs}\n",
        "\n",
        "\n",
        "# 질문 라우팅 노드\n",
        "def route_question(state):\n",
        "    print(\"==== [ROUTE QUESTION] ====\")\n",
        "    # 질문 가져오기\n",
        "    question = state[\"question\"]\n",
        "    # 질문 라우팅\n",
        "    source = question_router.invoke({\"question\": question})\n",
        "    # 질문 라우팅 결과에 따른 노드 라우팅\n",
        "    if source.datasource == \"web_search\":\n",
        "        print(\"==== [ROUTE QUESTION TO WEB SEARCH] ====\")\n",
        "        return \"web_search\"\n",
        "    elif source.datasource == \"vectorstore\":\n",
        "        print(\"==== [ROUTE QUESTION TO VECTORSTORE] ====\")\n",
        "        return \"vectorstore\"\n",
        "\n",
        "\n",
        "# 문서 관련성 평가 노드\n",
        "def decide_to_generate(state):\n",
        "    print(\"==== [DECISION TO GENERATE] ====\")\n",
        "    # 문서 검색 결과 가져오기\n",
        "    filtered_documents = state[\"documents\"]\n",
        "\n",
        "    if not filtered_documents:\n",
        "        # 모든 문서가 관련성 없는 경우 질문 재작성\n",
        "        print(\n",
        "            \"==== [DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY] ====\"\n",
        "        )\n",
        "        return \"transform_query\"\n",
        "    else:\n",
        "        # 관련성 있는 문서가 있는 경우 답변 생성\n",
        "        print(\"==== [DECISION: GENERATE] ====\")\n",
        "        return \"generate\"\n",
        "\n",
        "\n",
        "def hallucination_check(state):\n",
        "    print(\"==== [CHECK HALLUCINATIONS] ====\")\n",
        "    # 질문과 문서 검색 결과 가져오기\n",
        "    question = state[\"question\"]\n",
        "    documents = state[\"documents\"]\n",
        "    generation = state[\"generation\"]\n",
        "\n",
        "    # 환각 평가\n",
        "    score = hallucination_grader.invoke(\n",
        "        {\"documents\": documents, \"generation\": generation}\n",
        "    )\n",
        "    grade = score.binary_score\n",
        "\n",
        "    # Hallucination 여부 확인\n",
        "    if grade == \"yes\":\n",
        "        print(\"==== [DECISION: GENERATION IS GROUNDED IN DOCUMENTS] ====\")\n",
        "\n",
        "        # 답변의 관련성(Relevance) 평가\n",
        "        print(\"==== [GRADE GENERATED ANSWER vs QUESTION] ====\")\n",
        "        score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n",
        "        grade = score.binary_score\n",
        "\n",
        "        # 관련성 평가 결과에 따른 처리\n",
        "        if grade == \"yes\":\n",
        "            print(\"==== [DECISION: GENERATED ANSWER ADDRESSES QUESTION] ====\")\n",
        "            return \"relevant\"\n",
        "        else:\n",
        "            print(\"==== [DECISION: GENERATED ANSWER DOES NOT ADDRESS QUESTION] ====\")\n",
        "            return \"not relevant\"\n",
        "    else:\n",
        "        print(\"==== [DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY] ====\")\n",
        "        return \"hallucination\""
      ]
    },
    {
      "cell_type": "markdown",
      "id": "2412119d",
      "metadata": {},
      "source": [
        "## Graph Construction\n",
        "\n",
        "**그래프 컴파일** 단계에서는 **Adaptive RAG**의 워크플로우를 구축하고 실행 가능한 상태로 만듭니다. 이 과정은 그래프의 각 노드와 엣지를 연결하여 쿼리 처리의 전체 흐름을 정의합니다.\n",
        "\n",
        "- **노드 정의**: 각 노드를 정의하여 그래프의 상태와 전환을 명확히 합니다.\n",
        "- **엣지 설정**: 노드 간의 엣지를 설정하여 쿼리가 적절한 경로를 따라 진행되도록 합니다.\n",
        "- **워크플로우 구축**: 그래프의 전체 흐름을 구축하여 정보 검색과 생성의 효율성을 극대화합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "id": "c106a028",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langgraph.graph import END, StateGraph, START\n",
        "from langgraph.checkpoint.memory import MemorySaver\n",
        "\n",
        "# 그래프 상태 초기화\n",
        "workflow = StateGraph(GraphState)\n",
        "\n",
        "# 노드 정의\n",
        "workflow.add_node(\"web_search\", web_search)  # 웹 검색\n",
        "workflow.add_node(\"retrieve\", retrieve)  # 문서 검색\n",
        "workflow.add_node(\"grade_documents\", grade_documents)  # 문서 평가\n",
        "workflow.add_node(\"generate\", generate)  # 답변 생성\n",
        "workflow.add_node(\"transform_query\", transform_query)  # 쿼리 변환\n",
        "\n",
        "# 그래프 빌드\n",
        "workflow.add_conditional_edges(\n",
        "    START,\n",
        "    route_question,\n",
        "    {\n",
        "        \"web_search\": \"web_search\",  # 웹 검색으로 라우팅\n",
        "        \"vectorstore\": \"retrieve\",  # 벡터스토어로 라우팅\n",
        "    },\n",
        ")\n",
        "workflow.add_edge(\"web_search\", \"generate\")  # 웹 검색 후 답변 생성\n",
        "workflow.add_edge(\"retrieve\", \"grade_documents\")  # 문서 검색 후 평가\n",
        "workflow.add_conditional_edges(\n",
        "    \"grade_documents\",\n",
        "    decide_to_generate,\n",
        "    {\n",
        "        \"transform_query\": \"transform_query\",  # 쿼리 변환 필요\n",
        "        \"generate\": \"generate\",  # 답변 생성 가능\n",
        "    },\n",
        ")\n",
        "workflow.add_edge(\"transform_query\", \"retrieve\")  # 쿼리 변환 후 문서 검색\n",
        "workflow.add_conditional_edges(\n",
        "    \"generate\",\n",
        "    hallucination_check,\n",
        "    {\n",
        "        \"hallucination\": \"generate\",  # Hallucination 발생 시 재생성\n",
        "        \"relevant\": END,  # 답변의 관련성 여부 통과\n",
        "        \"not relevant\": \"transform_query\",  # 답변의 관련성 여부 통과 실패 시 쿼리 변환\n",
        "    },\n",
        ")\n",
        "\n",
        "# 그래프 컴파일\n",
        "app = workflow.compile(checkpointer=MemorySaver())"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "748f4505",
      "metadata": {},
      "source": [
        "그래프를 시각화 합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "46ce79fe",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_teddynote.graphs import visualize_graph\n",
        "\n",
        "visualize_graph(app)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3fd2739b",
      "metadata": {},
      "source": [
        "## Execute Graph\n",
        "\n",
        "**그래프 사용** 단계에서는 **Adaptive RAG**의 실행을 통해 쿼리 처리 결과를 확인합니다. 이 과정은 그래프의 각 노드와 엣지를 따라 쿼리를 처리하여 최종 결과를 생성합니다.\n",
        "\n",
        "- **그래프 실행**: 정의된 그래프를 실행하여 쿼리의 흐름을 따라갑니다.\n",
        "- **결과 확인**: 그래프 실행 후 생성된 결과를 검토하여 쿼리가 적절히 처리되었는지 확인합니다.\n",
        "- **결과 분석**: 생성된 결과를 분석하여 쿼리의 목적에 부합하는지 평가합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b020b140",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_teddynote.messages import stream_graph, random_uuid\n",
        "from langchain_core.runnables import RunnableConfig\n",
        "\n",
        "# config 설정(재귀 최대 횟수, thread_id)\n",
        "config = RunnableConfig(recursion_limit=20, configurable={\"thread_id\": random_uuid()})\n",
        "\n",
        "# 질문 입력\n",
        "inputs = {\n",
        "    \"question\": \"삼성전자가 개발한 생성형 AI 의 이름은?\",\n",
        "}\n",
        "\n",
        "# 그래프 실행\n",
        "stream_graph(app, inputs, config, [\"agent\", \"rewrite\", \"generate\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e25d23b6",
      "metadata": {},
      "outputs": [],
      "source": [
        "# 질문 입력\n",
        "inputs = {\n",
        "    \"question\": \"2024년 노벨 문학상 수상자는 누구인가요?\",\n",
        "}\n",
        "\n",
        "# 그래프 실행\n",
        "stream_graph(app, inputs, config, [\"agent\", \"rewrite\", \"generate\"])"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "langchain-kr-lwwSZlnu-py3.11",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}